{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Please ignore these variables; they only provide options for our CI system.\n",
    "args = []\n",
    "abort_after_one = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial: Federated learning with websockets and federated averaging, with solutions to problems you might face\n",
    "\n",
    "This notebook walks through the steps in detail and points out problems you might run into along the way.\n",
    "\n",
    "Make sure you have the correct websocket-client library installed. The websocket-client library is imported into your Python script with ```import websocket```, so if another websocket library is installed on top of websocket-client, Python will try to access that other library first. When you then try to create a connection with ```websocket.create_connection()```, it fails with an error saying that websocket has nothing named create_connection.\n",
    "\n",
    "Solution: in a terminal, activate the environment where syft is installed and run ```pip uninstall websocket``` to remove any additional websocket libraries, then run ```pip install --upgrade websocket_client```.\n",
    "\n",
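    "A quick way to check which library Python actually picks up is sketched below. The helper name has_create_connection is hypothetical (it is not part of syft or websocket-client); it only tests whether the imported module exposes create_connection.\n",
    "\n",
    "```python\n",
    "import importlib\n",
    "\n",
    "\n",
    "def has_create_connection(module_name='websocket'):\n",
    "    # Hypothetical helper: True only if the module imports and exposes create_connection.\n",
    "    try:\n",
    "        module = importlib.import_module(module_name)\n",
    "    except ImportError:\n",
    "        return False\n",
    "    return hasattr(module, 'create_connection')\n",
    "\n",
    "\n",
    "print(has_create_connection())\n",
    "```\n",
    "\n",
    "If this prints False, the uninstall and upgrade steps described above are worth trying.\n",
    "\n",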
    "Authors:\n",
    "- midokura-silvia\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparation: start the websocket server workers\n",
    "\n",
    "Each worker is represented by two parts, a local handle (websocket client worker) and the remote instance that holds the data and performs the computations. The remote part is called a websocket server worker.\n",
    "\n",
    "First, ```cd``` to the folder that contains this notebook and the additional scripts for running the server and client.\n",
    "\n",
    "For example, on Windows 10:\n",
    "\n",
    ">cd (path till projects directory) \\python_projects\\websockets-example-MNIST\n",
    "\n",
    "Note: Don't copy and paste the path above. It is only an example; your path will differ depending on your OS and project folder.\n",
    "\n",
    "This matters because ```start_websocket_servers.py``` opens subprocesses that run other Python scripts to start the websocket server workers, and it refers to those scripts by file name only (without a path), since the path may vary. If you run it from a different directory, the scripts may not be found.\n",
    "\n",
    "Now we need to create the remote workers. For this, run the following in a terminal (not possible from the notebook):\n",
    "\n",
    "```bash\n",
    "python start_websocket_servers.py\n",
    "```"
   ]
  },
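  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To confirm that the server workers are up before continuing, a port check along these lines can help. This is a minimal sketch: the helper name port_is_open is hypothetical, and the host and ports are assumed to match the ones used for alice, bob and charlie later in this notebook (localhost, ports 8777 to 8779).\n",
    "\n",
    "```python\n",
    "import socket\n",
    "\n",
    "\n",
    "def port_is_open(host, port, timeout=1.0):\n",
    "    # Try to open a TCP connection; an OSError means nothing accepted it.\n",
    "    try:\n",
    "        with socket.create_connection((host, port), timeout=timeout):\n",
    "            return True\n",
    "    except OSError:\n",
    "        return False\n",
    "\n",
    "\n",
    "for port in (8777, 8778, 8779):\n",
    "    print(port, port_is_open('localhost', port))\n",
    "```\n",
    "\n",
    "A False for any port suggests that the corresponding server worker has not started (or is listening elsewhere)."
   ]
  },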
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up the websocket client workers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first need to perform the imports and set up some arguments and variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import syft as sy\n",
    "from syft.workers.websocket_client import WebsocketClientWorker\n",
    "import torch\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "from syft.frameworks.torch.fl import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import run_websocket_client as rwc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = rwc.define_and_get_arguments(args=args)\n",
    "use_cuda = args.cuda and torch.cuda.is_available()\n",
    "torch.manual_seed(args.seed)\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "print(args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's instantiate the websocket client workers, our local access point to the remote workers.\n",
    "Note that **this step will fail if the websocket server workers are not running**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hook = sy.TorchHook(torch)\n",
    "\n",
    "kwargs_websocket = {\"host\": \"localhost\", \"hook\": hook, \"verbose\": args.verbose}\n",
    "alice = WebsocketClientWorker(id=\"alice\", port=8777, **kwargs_websocket)\n",
    "bob = WebsocketClientWorker(id=\"bob\", port=8778, **kwargs_websocket)\n",
    "charlie = WebsocketClientWorker(id=\"charlie\", port=8779, **kwargs_websocket)\n",
    "\n",
    "workers = [alice, bob, charlie]\n",
    "print(workers)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare and distribute the training data\n",
    "\n",
    "We will use the MNIST dataset and distribute the data randomly onto the workers. \n",
    "This is not realistic for a federated training setup, where the data would normally already be available at the remote workers.\n",
    "\n",
    "We instantiate two FederatedDataLoaders, one for the train and one for the test set of the MNIST dataset.\n",
    "\n",
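    "As a rough illustration of what distributing data randomly onto workers means, the sketch below shuffles sample indices and deals them out to the worker ids. It is an assumption-laden toy: split_indices is a hypothetical helper, and it is not how ```.federate()``` is implemented.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def split_indices(n_samples, worker_ids, seed=0):\n",
    "    # Shuffle the sample indices, then deal them out round-robin to the workers.\n",
    "    indices = list(range(n_samples))\n",
    "    random.Random(seed).shuffle(indices)\n",
    "    return {w: indices[i::len(worker_ids)] for i, w in enumerate(worker_ids)}\n",
    "\n",
    "\n",
    "shares = split_indices(10, ['alice', 'bob', 'charlie'])\n",
    "print({worker: len(ids) for worker, ids in shares.items()})\n",
    "```\n",
    "\n",
    "Every sample ends up on exactly one worker, and the share sizes differ by at most one.\n",
    "\n",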
    "*If you run into BrokenPipe errors, go to the parent directory of your project directory and delete the data folder there, then restart the notebook and try again. If the error occurs again, delete the data folder once more and run the following code cell, which downloads the dataset on its own.*\n",
    "\n",
    "For example, the directory that holds the data folder:\n",
    "\n",
    ">(path till projects directory) \\python_projects\\\n",
    "\n",
    "The directory with the project notebook and scripts:\n",
    "\n",
    ">(path till projects directory) \\python_projects\\websockets-example-MNIST\n",
    "\n",
    "Note: Don't copy and paste the paths above. They are only examples; your paths will differ depending on your OS and project folder.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this cell only if the next cell raises a BrokenPipe error.\n",
    "torch.utils.data.DataLoader(\n",
    "    datasets.MNIST(\n",
    "        \"../data\",\n",
    "        train=True,\n",
    "        download=True,\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "federated_train_loader = sy.FederatedDataLoader(\n",
    "    datasets.MNIST(\n",
    "        \"../data\",\n",
    "        train=True,\n",
    "        download=True,\n",
    "        transform=transforms.Compose(\n",
    "            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
    "        ),\n",
    "    ).federate(tuple(workers)),\n",
    "    batch_size=args.batch_size,\n",
    "    shuffle=True,\n",
    "    iter_per_worker=True\n",
    ")\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST(\n",
    "        \"../data\",\n",
    "        train=False,\n",
    "        transform=transforms.Compose(\n",
    "            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
    "        ),\n",
    "    ),\n",
    "    batch_size=args.test_batch_size,\n",
    "    shuffle=True\n",
    ")\n"
   ]
  },
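  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two numbers passed to transforms.Normalize above act as a mean and a standard deviation: each pixel value (already scaled to the range 0 to 1 by ToTensor) is mapped to (value - mean) / std. The plain-Python sketch below reproduces that arithmetic for a single pixel; the function name normalize is only illustrative.\n",
    "\n",
    "```python\n",
    "MEAN, STD = 0.1307, 0.3081\n",
    "\n",
    "\n",
    "def normalize(pixel):\n",
    "    # Same per-pixel formula as transforms.Normalize: (value - mean) / std.\n",
    "    return (pixel - MEAN) / STD\n",
    "\n",
    "\n",
    "# A black pixel (0.0) becomes negative, a white pixel (1.0) becomes about 2.82.\n",
    "print(normalize(0.0), normalize(1.0))\n",
    "```"
   ]
  },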
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we need to instantiate the machine learning model. It is a small neural network with two convolutional and two fully connected layers.\n",
    "It uses ReLU activations and max pooling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = rwc.Net().to(device)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import sys\n",
    "logger = logging.getLogger()\n",
    "logger.setLevel(logging.DEBUG)\n",
    "handler = logging.StreamHandler(sys.stderr)\n",
    "formatter = logging.Formatter(\"%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d) - %(message)s\")\n",
    "handler.setFormatter(formatter)\n",
    "logger.handlers = [handler]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's start the training\n",
    "\n",
    "\n",
    "Now we are ready to start the federated training. We will train for a given number of batches separately on each worker, then compute the federated average of the resulting models and measure the test accuracy of the averaged model."
   ]
  },
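  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The averaging step can be sketched in plain Python. This is not the code that rwc.train runs; it is a toy version that assumes every worker is weighted equally and represents each worker's model as a dict of parameter lists (both the representation and the name federated_average are hypothetical).\n",
    "\n",
    "```python\n",
    "def federated_average(worker_params):\n",
    "    # worker_params: one dict per worker, mapping parameter name -> list of floats.\n",
    "    n_workers = len(worker_params)\n",
    "    averaged = {}\n",
    "    for name in worker_params[0]:\n",
    "        columns = zip(*(params[name] for params in worker_params))\n",
    "        averaged[name] = [sum(column) / n_workers for column in columns]\n",
    "    return averaged\n",
    "\n",
    "\n",
    "models = [{'w': [1.0, 2.0]}, {'w': [3.0, 4.0]}, {'w': [5.0, 6.0]}]\n",
    "print(federated_average(models))  # {'w': [3.0, 4.0]}\n",
    "```\n",
    "\n",
    "Each entry of the averaged model is the mean of the corresponding entries across the workers."
   ]
  },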
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "for epoch in range(1, args.epochs + 1):\n",
    "    print(\"Starting epoch {}/{}\".format(epoch, args.epochs))\n",
    "    model = rwc.train(model, device, federated_train_loader, args.lr, args.federate_after_n_batches, \n",
    "                      abort_after_one=abort_after_one)\n",
    "    rwc.test(model, device, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! At any time you can go to PySyft GitHub Issues page and filter for \"Projects\". This will show you all the top level Tickets giving an overview of what projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more \"one off\" mini-projects by searching for GitHub issues marked \"good first issue\".\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
